[WIP] Feat/gpt oss example#63
Conversation
Add a NKIPy example for OpenAI's gpt-oss MoE models (gpt-oss-20b / 120b), mirroring the qwen3 example structure. The implementation is fully config-driven, so both sizes share one codebase. gpt-oss-specific handling: - MXFP4 experts dequantized to bf16 at prep time - interleaved gate/up de-interleaved at prep time - clamped SwiGLU with gate_up/down biases - per-head attention sinks + QKV/O biases (no QK-norm) - alternating sliding-window / full attention (one kernel per type) - YaRN RoPE (inv_freq precomputed from HF config) - router with top-k-then-softmax and router bias Validated against HF on trn2 (TP=4): every generated token matches HF's argmax or a bf16-resolution tie.
Implements parallel-drafting P-EAGLE (arXiv 2602.01469) on top of the gpt-oss base model for speculative decoding on Trainium. Components added (examples/models/gpt_oss/eagle/): - config.py: EagleConfig for the 4-layer P-EAGLE drafter (llama3 RoPE, fc fusion, mask_hidden/ptd_token_id, d2t vocab map) - tensor_preparation.py: convert P-EAGLE checkpoint to x@W form (replicated) - kernels/drafter.py: parallel-drafting forward - K tokens in one pass via NTP (real hidden) + MTP (mask_hidden) positions with cross-depth mask - kernels/drafter_layer.py: EAGLE-3 fusion midlayer + plain Llama layers - kernels/verify.py: multi-position greedy argmax for verification - drafter_model.py: device-side drafter model + compile - speculate.py: full speculation loop (prefill → draft → verify → accept) Base model changes: - config.py: added aux_layers config + default_aux_layers() for EAGLE-3 taps - gpt_oss.py: run_prefill() now optionally captures pre-layer hidden states at the 3 EAGLE-3 tap layers (2, L/2, L-3) - kernels/attention.py: generalized decode path to support seq_len>1 (for the multi-token verify pass) via query_pos = start_pos + arange(seq_len) Status: functionally correct (lossless greedy output verified against HF). Acceptance length is below the paper's reported ~3.3 — under investigation (likely a hidden-state position/timing issue in the draft-verify loop seeding).
…yers Switch aux capture to post-layer (output of tap layers 2/12/21) based on HF validation showing the drafter predicts correctly with HF's hidden states at hs[3]/hs[13]/hs[22] (output of layers 2/12/21). Note: acceptance length remains low (~1.0) due to numerical divergence between nkipy's Neuron-compiled target and the HF CPU reference the drafter was trained against. The drafter kernel is mathematically correct (validated against independent torch reference) and correctly predicts the target when fed exact HF hidden states. The gap is an implementation-coupling issue inherent to EAGLE-style speculation.
Key findings from the P-EAGLE paper (Figure 2, Figure 3, Section 3):
1. The drafter maintains its own KV cache across the full context
(prompt + all accepted tokens). At each draft step, K positions
attend to the FULL accumulated cache.
2. The attention mask is GROUP-CAUSAL: all K positions see the full
cache (group 0), but within the K positions the NTP (group 1)
and MTP (group 2+) positions use cross-depth causality — MTP
positions cannot attend to positions at the same or later depth.
3. The NTP pair is (emb(t_n), hidden_after_processing_t_{n-1}),
predicting t_{n+1}. The hidden is one step behind the embedding.
This commit adds:
- drafter_cpu.py: CPU reference drafter with full KV cache and
standard causal attention (working infrastructure, mask needs
the group-causal refinement for MTP positions)
- Fixes hidden state capture to post-layer (output of tap layers)
- Adds peagle_aux_layers config method
Status: KV cache infrastructure correct, still needs the group-causal
mask refinement for the MTP positions within the K-wide draft window.
Root-caused the low acceptance length (~1.4 vs the card's 3.30-3.80 at K=7) on GPU by running the identical checkpoint through vLLM's eagle3 parallel-drafting path, capturing its drafter I/O, and reproducing it with a standalone PyTorch reference (cosine 0.9999, 100% draft-token match). Three bugs plus a prompt-formatting issue: 1. Context-blind drafting (dominant): speculate.py drove DrafterModel (kernels/drafter.py), which runs only the K draft positions under a (K,K) cross-depth mask with no prefill and no KV cache, so the MTP slots never saw the prompt. Rewired speculate.py to use the KV-cached DrafterCPU: prefill the drafter on the prompt (EAGLE +1 shift), then each step roll the cache back to the last accepted position and run [newly-accepted tokens | K-1 ptd slots] in one parallel forward attending to the full context. 2. rollback() truncated the wrong axis: the cache is (B, n_kv, seq, head_dim) and rollback sliced dim 1 (n_kv) instead of dim 2 (seq), so rejected speculative KV was never discarded and corrupted later steps. 3. Aux tap off-by-one: vLLM's eagle3 default (2, n//2, n-3) captures the residual stream entering those layers; our post-layer capture must shift down one, so default_aux_layers now returns (1, 11, 20). Verified on GPU the drafter's 3 fc chunks equal target layer outputs (1, 11, 20) at cos 1.0. 4. Prompt formatting: the drafter is trained on chat data; raw prompts roughly halve acceptance (GPU, K=7: 3.65 chat vs 1.99 raw). speculate.py now applies the chat template by default (--raw-prompt to opt out). Still produces all K draft tokens in a single forward pass (parallel drafting). Adds test_drafter_cpu.py guarding the rollback/full-context invariants (skips without the checkpoint). Validated against vLLM on GPU; not yet re-validated on Trainium, and the on-device kernels/drafter.py KV-cache port remains follow-up. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
gpt-oss + P-EAGLE — Implementation Status ReportStatus as of the current branch TL;DR
Components
What changed recently (this work)The drafter had a working CPU reference but the on-device path was the old
Port design note: the device drafter keeps static kernel shapes — it commits Also fixed along the way: Validation
Under greedy verification emitted tokens are the target's argmax regardless of Performance (5-prompt sweep, n=128, K=7, TP=4, chat template, identical env)
(Full per-prompt table in Known issues / open work
|
Add p-eagle gpt-oss-20b example.